{
 "metadata": {
  "name": "",
  "signature": "sha256:a5a54a4647dcc8ec013249e99b0ecb96ca8ad33c605834a4a188b58d9e758929"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "# Applying Bayes' theorem to iris classification\n",
      "\n",
      "Let's see if **Bayes' theorem** might be able to help us solve a **classification task**, namely predicting the species of an iris!"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Preparing the data\n",
      "\n",
      "We'll load the iris data into a DataFrame, and **round up** all of the measurements to the next integer:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.datasets import load_iris\n",
      "import numpy as np\n",
      "import pandas as pd"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# load the iris data\n",
      "iris = load_iris()\n",
      "\n",
      "# round up the measurements\n",
      "X = np.ceil(iris.data)\n",
      "\n",
      "# clean up column names\n",
      "col_names = [name[:-5].replace(' ', '_') for name in iris.feature_names]\n",
      "\n",
      "# read into pandas\n",
      "df = pd.DataFrame(X, columns=col_names)\n",
      "\n",
      "# create a list of species using iris.target and iris.target_names\n",
      "species = [iris.target_names[num] for num in iris.target]\n",
      "\n",
      "# add the species list as a new DataFrame column\n",
      "df['species'] = species"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
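    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "As a quick aside, the slice `name[:-5]` in the cell above drops the last five characters of each feature name, which removes the trailing ` (cm)` unit. The sketch below applies the same list comprehension to the four feature names, hardcoded here (an assumption, so that the cell runs on its own) rather than read from `iris.feature_names`. The names `raw_names` and `cleaned` are illustrative only:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# feature names as reported by load_iris().feature_names (hardcoded for illustration)\n",
      "raw_names = ['sepal length (cm)', 'sepal width (cm)',\n",
      "             'petal length (cm)', 'petal width (cm)']\n",
      "\n",
      "# drop the trailing ' (cm)' (5 characters), then replace spaces with underscores\n",
      "cleaned = [name[:-5].replace(' ', '_') for name in raw_names]\n",
      "print(cleaned)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },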
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# print the head\n",
      "df.head()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "html": [
        "<div style=\"max-width:1500px;overflow:auto;\">\n",
        "<table border=\"1\" class=\"dataframe\">\n",
        "  <thead>\n",
        "    <tr style=\"text-align: right;\">\n",
        "      <th></th>\n",
        "      <th>sepal_length</th>\n",
        "      <th>sepal_width</th>\n",
        "      <th>petal_length</th>\n",
        "      <th>petal_width</th>\n",
        "      <th>species</th>\n",
        "    </tr>\n",
        "  </thead>\n",
        "  <tbody>\n",
        "    <tr>\n",
        "      <th>0</th>\n",
        "      <td>6</td>\n",
        "      <td>4</td>\n",
        "      <td>2</td>\n",
        "      <td>1</td>\n",
        "      <td>setosa</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>1</th>\n",
        "      <td>5</td>\n",
        "      <td>3</td>\n",
        "      <td>2</td>\n",
        "      <td>1</td>\n",
        "      <td>setosa</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
        "      <td>5</td>\n",
        "      <td>4</td>\n",
        "      <td>2</td>\n",
        "      <td>1</td>\n",
        "      <td>setosa</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>3</th>\n",
        "      <td>5</td>\n",
        "      <td>4</td>\n",
        "      <td>2</td>\n",
        "      <td>1</td>\n",
        "      <td>setosa</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
        "      <td>5</td>\n",
        "      <td>4</td>\n",
        "      <td>2</td>\n",
        "      <td>1</td>\n",
        "      <td>setosa</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
        "</div>"
       ],
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 3,
       "text": [
        "   sepal_length  sepal_width  petal_length  petal_width species\n",
        "0             6            4             2            1  setosa\n",
        "1             5            3             2            1  setosa\n",
        "2             5            4             2            1  setosa\n",
        "3             5            4             2            1  setosa\n",
        "4             5            4             2            1  setosa"
       ]
      }
     ],
     "prompt_number": 3
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Deciding how to make a prediction\n",
      "\n",
      "Let's say that I had an **out-of-sample observation** with the following measurements: **7, 3, 5, 2**. I want to predict the species of this iris. How might I do that?\n",
      "\n",
      "We'll first examine all observations in the **training data** with those measurements:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# show all observations with features: 7, 3, 5, 2\n",
      "df[(df.sepal_length==7) & (df.sepal_width==3) & (df.petal_length==5) & (df.petal_width==2)]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "html": [
        "<div style=\"max-width:1500px;overflow:auto;\">\n",
        "<table border=\"1\" class=\"dataframe\">\n",
        "  <thead>\n",
        "    <tr style=\"text-align: right;\">\n",
        "      <th></th>\n",
        "      <th>sepal_length</th>\n",
        "      <th>sepal_width</th>\n",
        "      <th>petal_length</th>\n",
        "      <th>petal_width</th>\n",
        "      <th>species</th>\n",
        "    </tr>\n",
        "  </thead>\n",
        "  <tbody>\n",
        "    <tr>\n",
        "      <th>54</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>58</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>63</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>68</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>72</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>73</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>74</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>75</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>76</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>77</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>87</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>91</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>97</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>versicolor</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>123</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>virginica</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>126</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>virginica</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>127</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>virginica</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>146</th>\n",
        "      <td>7</td>\n",
        "      <td>3</td>\n",
        "      <td>5</td>\n",
        "      <td>2</td>\n",
        "      <td>virginica</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
        "</div>"
       ],
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 4,
       "text": [
        "     sepal_length  sepal_width  petal_length  petal_width     species\n",
        "54              7            3             5            2  versicolor\n",
        "58              7            3             5            2  versicolor\n",
        "63              7            3             5            2  versicolor\n",
        "68              7            3             5            2  versicolor\n",
        "72              7            3             5            2  versicolor\n",
        "73              7            3             5            2  versicolor\n",
        "74              7            3             5            2  versicolor\n",
        "75              7            3             5            2  versicolor\n",
        "76              7            3             5            2  versicolor\n",
        "77              7            3             5            2  versicolor\n",
        "87              7            3             5            2  versicolor\n",
        "91              7            3             5            2  versicolor\n",
        "97              7            3             5            2  versicolor\n",
        "123             7            3             5            2   virginica\n",
        "126             7            3             5            2   virginica\n",
        "127             7            3             5            2   virginica\n",
        "146             7            3             5            2   virginica"
       ]
      }
     ],
     "prompt_number": 4
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# count the species for these observations\n",
      "df[(df.sepal_length==7) & (df.sepal_width==3) & (df.petal_length==5) & (df.petal_width==2)].species.value_counts()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 5,
       "text": [
        "versicolor    13\n",
        "virginica      4\n",
        "dtype: int64"
       ]
      }
     ],
     "prompt_number": 5
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# count the species for all observations\n",
      "df.species.value_counts()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 6,
       "text": [
        "setosa        50\n",
        "versicolor    50\n",
        "virginica     50\n",
        "dtype: int64"
       ]
      }
     ],
     "prompt_number": 6
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Okay, so how might **Bayes' theorem** help us here?\n",
      "\n",
      "Let's frame this as a **conditional probability**: What is the probability of some particular class, given the measurements 7352?\n",
      "\n",
      "$$P(class | 7352)$$\n",
      "\n",
      "We could calculate this conditional probability for **each of the three classes**, and then predict the class with the **highest probability**:\n",
      "\n",
      "$$P(setosa | 7352)$$\n",
      "$$P(versicolor | 7352)$$\n",
      "$$P(virginica | 7352)$$"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Calculating the probability of each class\n",
      "\n",
      "Let's start with **versicolor**:\n",
      "\n",
      "$$P(versicolor | 7352) = \\frac {P(7352 | versicolor) \\times P(versicolor)} {P(7352)}$$"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "We'll calculate each of the terms on the right side of the equation:\n",
      "\n",
      "$$P(7352 | versicolor) = \\frac {13} {50} = 0.26$$\n",
      "\n",
      "$$P(versicolor) = \\frac {50} {150} = 0.33$$\n",
      "\n",
      "$$P(7352) = \\frac {17} {150} = 0.11$$"
     ]
    },
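    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "As a consistency check, the denominator can also be written as a sum over the three classes (the law of total probability). This is a sketch that reuses the counts shown earlier: 13 versicolors, 4 virginicas and 0 setosas are 7352, out of 50 irises per species:\n",
      "\n",
      "$$P(7352) = \\sum_{class} P(7352 | class) \\times P(class) = \\frac {13} {50} \\times \\frac {50} {150} + \\frac {4} {50} \\times \\frac {50} {150} + \\frac {0} {50} \\times \\frac {50} {150} = \\frac {17} {150}$$"
     ]
    },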
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Therefore, Bayes' theorem says the **probability of versicolor given these measurements** is:\n",
      "\n",
      "$$P(versicolor | 7352) = \\frac {0.26 \\times 0.33} {0.11} = 0.76$$"
     ]
    },
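    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The arithmetic above can be checked with exact fractions, which avoids the small error that creeps in when the rounded values 0.33 and 0.11 are multiplied and divided directly. This is a minimal sketch: the variable names are illustrative, and the counts are copied by hand from the `value_counts()` outputs earlier in the notebook rather than recomputed from `df`:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from fractions import Fraction\n",
      "\n",
      "# counts copied by hand from the value_counts() outputs above (an assumption of this sketch)\n",
      "n_versicolor_7352 = 13   # versicolors measuring 7, 3, 5, 2\n",
      "n_versicolor = 50        # all versicolors\n",
      "n_7352 = 17              # all irises measuring 7, 3, 5, 2\n",
      "n_total = 150            # all irises\n",
      "\n",
      "likelihood = Fraction(n_versicolor_7352, n_versicolor)   # P(7352 | versicolor)\n",
      "prior = Fraction(n_versicolor, n_total)                  # P(versicolor)\n",
      "evidence = Fraction(n_7352, n_total)                     # P(7352)\n",
      "\n",
      "# Bayes' theorem: P(versicolor | 7352)\n",
      "posterior = likelihood * prior / evidence\n",
      "print(posterior)\n",
      "print(round(float(posterior), 2))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },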
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Let's repeat this process for the other two classes, though we already know that versicolor will have the highest probability:\n",
      "\n",
      "$$P(virginica | 7352) = \\frac {0.08 \\times 0.33} {0.11} = 0.24$$\n",
      "\n",
      "$$P(setosa | 7352) = \\frac {0 \\times 0.33} {0.11} = 0$$"
     ]
    },
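    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "The same calculation can be repeated for all three classes in a loop. The sketch below is one way to do it, again with counts copied by hand from the outputs above; `counts_7352`, `class_totals` and `posteriors` are illustrative names. Because every 7352 iris belongs to exactly one species, the three posteriors should sum to 1:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from fractions import Fraction\n",
      "\n",
      "# number of 7352 observations per species, and species totals\n",
      "# (copied by hand from the outputs above, an assumption of this sketch)\n",
      "counts_7352 = {'setosa': 0, 'versicolor': 13, 'virginica': 4}\n",
      "class_totals = {'setosa': 50, 'versicolor': 50, 'virginica': 50}\n",
      "\n",
      "n_total = sum(class_totals.values())                      # 150\n",
      "evidence = Fraction(sum(counts_7352.values()), n_total)   # P(7352) = 17/150\n",
      "\n",
      "posteriors = {}\n",
      "for name in class_totals:\n",
      "    likelihood = Fraction(counts_7352[name], class_totals[name])   # P(7352 | class)\n",
      "    prior = Fraction(class_totals[name], n_total)                  # P(class)\n",
      "    posteriors[name] = likelihood * prior / evidence               # P(class | 7352)\n",
      "\n",
      "# predict the class with the highest posterior probability\n",
      "prediction = max(posteriors, key=posteriors.get)\n",
      "print(posteriors)\n",
      "print(prediction)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },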
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "In summary, we framed a **classification problem** as three conditional probability equations, we used **Bayes' theorem** to solve those equations, and then we made a **prediction** by choosing the class with the highest conditional probability."
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "## Adjusting the data\n",
      "\n",
      "Let's make some hypothetical adjustments to the data, to demonstrate how Bayes' theorem actually makes intuitive sense:\n",
      "\n",
      "Pretend that **more of the existing versicolors were 7352:**\n",
      "\n",
      "- $P(7352|versicolor)$ would increase, thus increasing the numerator.\n",
      "- It would make sense that given an iris with 7352, the probability of it being a versicolor would also increase.\n",
      "\n",
      "Pretend that **most of the existing irises were versicolor:**\n",
      "\n",
      "- $P(versicolor)$ would increase, thus increasing the numerator.\n",
      "- It would make sense that the probability of any iris being a versicolor (regardless of measurements) would also increase.\n",
      "\n",
      "Pretend that **17 of the setosas were 7352:**\n",
      "\n",
      "- $P(7352)$ would double, thus doubling the denominator.\n",
      "- It would make sense that given an iris with 7352, the probability of it being a versicolor would be cut in half."
     ]
    }
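    ,
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "These thought experiments can be tried numerically. The helper below, `posterior` (a hypothetical name introduced for this sketch), recomputes $P(versicolor | 7352)$ from per-species counts, so each scenario is just a different set of counts. The adjusted numbers (20 matching versicolors, or 100 versicolors of which 26 match so that the likelihood stays at 0.26) are made-up illustrations, not data:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from fractions import Fraction\n",
      "\n",
      "def posterior(target, matches, totals):\n",
      "    # P(target | 7352), given per-species counts of 7352 irises and of all irises\n",
      "    n_total = sum(totals.values())\n",
      "    likelihood = Fraction(matches[target], totals[target])   # P(7352 | target)\n",
      "    prior = Fraction(totals[target], n_total)                # P(target)\n",
      "    evidence = Fraction(sum(matches.values()), n_total)      # P(7352)\n",
      "    return likelihood * prior / evidence\n",
      "\n",
      "# baseline counts copied by hand from the notebook outputs\n",
      "totals = {'setosa': 50, 'versicolor': 50, 'virginica': 50}\n",
      "baseline = posterior('versicolor', {'setosa': 0, 'versicolor': 13, 'virginica': 4}, totals)\n",
      "\n",
      "# scenario 1: more versicolors are 7352 (20 instead of 13, a made-up number)\n",
      "more_matches = posterior('versicolor', {'setosa': 0, 'versicolor': 20, 'virginica': 4}, totals)\n",
      "\n",
      "# scenario 2: versicolor is more common (100 of 200), likelihood held at 26/100 = 0.26\n",
      "bigger_prior = posterior('versicolor', {'setosa': 0, 'versicolor': 26, 'virginica': 4},\n",
      "                         {'setosa': 50, 'versicolor': 100, 'virginica': 50})\n",
      "\n",
      "# scenario 3: 17 setosas are also 7352, which doubles P(7352)\n",
      "setosa_matches = posterior('versicolor', {'setosa': 17, 'versicolor': 13, 'virginica': 4}, totals)\n",
      "\n",
      "for label, value in [('baseline', baseline), ('more matches', more_matches),\n",
      "                     ('bigger prior', bigger_prior), ('setosa matches', setosa_matches)]:\n",
      "    print('%s: %.2f' % (label, float(value)))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    }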
   ],
   "metadata": {}
  }
 ]
}